package com.auditlog.sql.runner;

import cn.hutool.core.collection.CollectionUtil;
import com.auditlog.datasource.ParametersHolder;
import com.auditlog.datasource.StatementProxy;
import com.auditlog.datasource.table.ColumnMeta;
import com.auditlog.sql.util.ConstantsExpressionUtil;
import com.auditlog.sql.dynamic.DynamicValueJudgeRegistry;
import com.auditlog.exception.NotSupportException;
import com.auditlog.util.IndexUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.JdbcParameter;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.util.deparser.ExpressionDeParser;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.stream.Collectors;


/**
 * <b>1.只支持占位符(?)是独立，不在任何表达式中</b><br></br>
 * <b>2.不要写成insert into t select t1.*,pk1,pk2,t2.* from t1,t2这种形式，这样会多一次查询</b><br></br>
 * <b>&nbsp;&nbsp;2.1可以写成 insert into t select pk1, clomun1, pk2, t2.* from t2这种形式</b><br></br>
 * <b>&nbsp;&nbsp;2.2可以写成 insert into t select t2.*,pk1,clomun1,pk2 from t2这种形式</b><br></br>
 *
 * @author Zhiyang.Zhang
 * @version 1.0
 * @date 2022/11/4 23:48
 */
@Slf4j
public abstract class AbstractInsertSqlRunner<T extends Statement> extends AbstractSqlRunner<T, Insert> {

    protected List<List<Object>> insertSelectColumnValues = null;

    protected List<String> columnsContainsNoAutoIncreasePk = null;

    public AbstractInsertSqlRunner(StatementProxy<T> statementProxy, Insert insert) {
        super(statementProxy, insert);
        List<Column> setColumns = this.getStatement().getSetColumns();
        if (CollectionUtil.isNotEmpty(setColumns)) {
            throw new NotSupportException("不支持insert set解析.");
        }
    }

    public boolean canUseReturnGeneratedKeys() {
        return this.isPrepareStatement() &&
                (
                        // oracle: insert select不能获取，batch也不能获取
                        !this.isBatch() && this.getStatement().getSelect() == null && this.isOracle() ||
                                // mysql:没有设置主键，且自增的
                                !this.insertContainsPk() && this.getTableMeta().isAutoIncreasePk() ||
                                !this.getStatementProxy().getConnectionProxy().getConnectionContext().isPrepareMultiSql() &&
                                        this.getStatement().getSelect() == null && !this.isMultiValue() && !this.isOracle()
                ) ||
                !this.isPrepareStatement() &&
                        // executeUpdate(sql)
                        (!this.isBatch() && this.getStatement().getSelect() == null && !this.isMultiValue());
    }

    public boolean isMultiValue() {
        return this.getStatement().getItemsList() instanceof MultiExpressionList;
    }

    /**
     * 获取所有主键的值，为空表示主键中不是所有的都为常量，需要进行select获取
     *
     * @return: java.util.List<java.util.List < java.lang.Object>>
     */
    public List<List<Object>> getPkValues() throws SQLException {
        return getColumnValues(this.getColumnWithAliasIndexes(this.getTableMeta().getPkNameListWithAlias(this.getAlias())));
    }

    /**
     * 是否是符合主键，且只未包含自增的列
     *
     * @return: boolean
     */
    public boolean insertMultiPkOnlyWithoutAutoIncrease() {
        Map<String, ColumnMeta> primaryKeyMap = this.getTableMeta().getPrimaryKeyMap();
        if (this.columnsContainsNoAutoIncreasePk == null) {
            if (primaryKeyMap.size() <= 1) {
                return false;
            }
            List<String> autoIncreaseColumns = new ArrayList<>();
            List<String> notAutoIncreaseColumns = new ArrayList<>();
            primaryKeyMap.forEach((k, v) -> {
                if (v.isAutoincrement()) {
                    autoIncreaseColumns.add(k);
                } else {
                    notAutoIncreaseColumns.add(k);
                }
            });
            this.columnsContainsNoAutoIncreasePk = notAutoIncreaseColumns;
            if (autoIncreaseColumns.size() == 1) {
                return true;
            }
            return false;
        } else {
            return primaryKeyMap.size() - this.columnsContainsNoAutoIncreasePk.size() == 1;
        }
    }

    /**
     * 获取指定列的值
     *
     * @param columnIndex
     * @return: java.util.List<java.util.List < java.lang.Object>>
     */
    public List<List<Object>> getColumnValues(List<Integer> columnIndex) throws SQLException {
        List<List<Object>> columnValuesList = new ArrayList<>();
        List<List<Expression>> columnValueExpressions = getColumnValueExpressions(columnIndex);
        // 保证全部是常量之后才进行后面的计算
        boolean matched = columnValueExpressions.stream().allMatch(expressionList ->
                expressionList.stream().allMatch(expression -> ConstantsExpressionUtil.isConstant(expression)));
        if (matched && columnValueExpressions.size() > 0) {
            Insert insert = this.getStatement();
            ItemsList itemsList = insert.getItemsList();
            if (itemsList != null && itemsList instanceof MultiExpressionList) {
                // 如果是union中含有占位符，也应该走这里
                List<Expression> multiExpressionList = new ArrayList<>();
                // 放入一个list为了获取占位符的值
                columnValueExpressions.forEach(expressionList -> multiExpressionList.addAll(expressionList));
                // 一个List<Object>含有多行数据(values (),...,()的情况)
                List<List<Object>> constantExpressionValue = this.getConstantExpressionValue(multiExpressionList);
                List<Object> mergeList = new ArrayList<>();
                constantExpressionValue.forEach(list -> mergeList.addAll(list));
                int size = columnIndex.size();
                // 按照主键的列数进行拆分
                columnValuesList.addAll(splitList(mergeList, size));
            } else {
                // 这里是values后面只有一个（），或者是select的情况和select union的情况
                for (List<Expression> expressionList : columnValueExpressions) {
                    // 一个List<Object>对应一行数据
                    List<List<Object>> constantExpressionValue = this.getConstantExpressionValue(expressionList);
                    columnValuesList.addAll(constantExpressionValue);
                }
            }
        }
        return columnValuesList;
    }


    /**
     * 判断主键中是否包含动态值（比如：sequence等）
     *
     * @return: boolean
     */
    public boolean pkContainsDynamicValue() {
        if (insertContainsPk()) {
            List<List<Expression>> pkValuesExpression = getPkColumnValueExpressions();
            if (CollectionUtil.isEmpty(pkValuesExpression) && this.insertContainsPk()) {
                // 如果不能明确的解析出所有的pk列，那么就比较所有列
                List<List<Expression>> columnValueExpression = getAllInsertColumnValueExpression();
                return this.expressionsContainsDynamicValue(columnValueExpression);
            } else {
                // 解析出了所有的pk列
                this.expressionsContainsDynamicValue(pkValuesExpression);
            }
        }
        return false;
    }

    public boolean expressionsContainsDynamicValue(List<List<Expression>> expressionsList) {
        return expressionsList.stream().anyMatch(expressions ->
                expressions.stream().anyMatch(expression -> DynamicValueJudgeRegistry.getInstance().isDynamicValue(expression, this.getStatementProxy().getConnectionProxy())));
    }

    /**
     * 获取主键列表达式
     *
     * @return: java.util.List<java.util.List < net.sf.jsqlparser.expression.Expression>>
     */
    public List<List<Expression>> getPkColumnValueExpressions() {
        List<Integer> pkIndexes = this.getColumnWithAliasIndexes(this.getTableMeta().getPkNameListWithAlias(this.getAlias()));
        return getColumnValueExpressions(pkIndexes);
    }

    /**
     * 根据插入列的下标获取插入的值
     *
     * @return: 主键列表达式集合
     */
    public List<List<Expression>> getColumnValueExpressions(List<Integer> columnIndexes) {
        List<List<Expression>> columnValuesExpressions = new ArrayList<>();
        Insert insert = this.getStatement();
        if (CollectionUtil.isNotEmpty(columnIndexes)) {
            ItemsList itemsList = insert.getItemsList();
            if (itemsList != null) {
                if (itemsList instanceof ExpressionList) {
                    ExpressionList expressionList = (ExpressionList) itemsList;
                    columnValuesExpressions.add(getExpressionByIndex(expressionList.getExpressions(), columnIndexes));
                } else if (itemsList instanceof MultiExpressionList) {
                    MultiExpressionList multiExpressionList = (MultiExpressionList) itemsList;
                    multiExpressionList.getExpressionLists()
                            .forEach(expressionList -> columnValuesExpressions.add(getExpressionByIndex(expressionList.getExpressions(), columnIndexes)));
                }
            } else {
                // select情况
                Select select = insert.getSelect();
                // 要注意查询全部(*)的情况
                List<List<Expression>> expressionsList = getExpressions(select.getSelectBody());
                Integer maxIndex = IndexUtils.getMaxNumber(columnIndexes);
                List<List<Expression>> filteredExpression = expressionsList.stream().filter(expressionList -> expressionList.size() >= maxIndex)
                        .filter(expressionList -> {
                            for (int i = 0; i < maxIndex; i++) {
                                Expression expression = expressionList.get(i);
                                // 保证键之间无*,过滤（select pk_1,t.* ,pk2）这种
                                if (isSelectAll(expression)) {
                                    return false;
                                }
                            }
                            return true;
                        }).collect(Collectors.toList());
                // 如果不相等，表示需要查询才能得到结果，相等的话只有全部是常数才能确定值（ps:主键不可能相等，指定了相等的主键数据库也会报错）
                if (expressionsList.size() == filteredExpression.size()) {
                    expressionsList.forEach(expressionList -> {
                        columnValuesExpressions.add(getExpressionByIndex(expressionList, columnIndexes));
                    });
                } else {
                    // 反向查找一次  insert into t select t2.*,pk1,clomun1,pk2 from t2 这种形式
                    Integer minIndex = IndexUtils.getMinNumber(columnIndexes);
                    int columnCount = this.extractInsertColumns().size();
                    int validCount = columnCount - minIndex + 1;
                    List<List<Expression>> reverseFilteredExpression = expressionsList.stream().filter(expressionList -> expressionList.size() >= validCount)
                            .filter(expressionList -> {
                                for (int i = expressionList.size() - 1; i >= minIndex - 1; i--) {
                                    Expression expression = expressionList.get(i);
                                    // 保证键之间无*,过滤（select pk_1,t.* ,pk2）这种
                                    if (isSelectAll(expression)) {
                                        return false;
                                    }
                                }
                                return true;
                            }).map(expressionList -> CollectionUtil.reverse(expressionList)).collect(Collectors.toList());
                    if (expressionsList.size() == reverseFilteredExpression.size()) {
                        List<Integer> reversePkIndex = CollectionUtil.reverse(columnIndexes).stream().map(integer -> columnCount - integer + 1).collect(Collectors.toList());
                        reverseFilteredExpression.forEach(expressions -> columnValuesExpressions.add(CollectionUtil.reverse(getExpressionByIndex(expressions, reversePkIndex))));
                    }
                }
            }
        }
        return columnValuesExpressions;
    }

    /**
     * 获取待插入列的表达式
     *
     * @return: java.util.List<java.util.List < net.sf.jsqlparser.expression.Expression>>
     */
    public List<List<Expression>> getAllInsertColumnValueExpression() {
        List<List<Expression>> allColumnValuesExpressions = new ArrayList<>();
        Insert insert = this.getStatement();
        ItemsList itemsList = insert.getItemsList();
        if (itemsList != null) {
            if (itemsList instanceof ExpressionList) {
                ExpressionList expressionList = (ExpressionList) itemsList;
                allColumnValuesExpressions.add(expressionList.getExpressions());
            } else if (itemsList instanceof MultiExpressionList) {
                MultiExpressionList multiExpressionList = (MultiExpressionList) itemsList;
                multiExpressionList.getExpressionLists()
                        .forEach(expressionList -> allColumnValuesExpressions.add((expressionList.getExpressions())));
            }
        } else {
            // select情况
            Select select = insert.getSelect();
            // 要注意查询全部(*)的情况
            allColumnValuesExpressions.addAll(getExpressions(select.getSelectBody()));
        }
        return allColumnValuesExpressions;
    }

    public List<Expression> getExpressionByIndex(List<Expression> expressionList, List<Integer> index) {
        List<Expression> expressions = new ArrayList<>();
        index.forEach(i -> expressions.add(expressionList.get(i - 1)));
        return expressions;
    }


    /**
     * 获取SelectBody中的所有select字段，包括union
     *
     * @param selectBody
     * @return: java.util.List<java.util.List < net.sf.jsqlparser.expression.Expression>>
     */
    public List<List<Expression>> getExpressions(SelectBody selectBody) {
        List<List<Expression>> expressions = new ArrayList<>();
        if (selectBody instanceof PlainSelect) {
            PlainSelect plainSelect = (PlainSelect) selectBody;
            List<SelectItem> selectItems = plainSelect.getSelectItems();
            List<Expression> expressionList = new ArrayList<>();
            selectItems.forEach(selectItem -> {
                if (selectItem instanceof AllTableColumns) { // table.*
                    expressionList.add((AllTableColumns) selectItem);
                } else if (selectItem instanceof AllColumns) { // *
                    expressionList.add((AllColumns) selectItem);
                } else if (selectItem instanceof SelectExpressionItem) { // 具体列
                    expressionList.add(((SelectExpressionItem) selectItem).getExpression());
                }
            });
            expressions.add(expressionList);
        } else if (selectBody instanceof SetOperationList) {
            SetOperationList setOperationList = (SetOperationList) selectBody;
            List<SelectBody> selects = setOperationList.getSelects();
            selects.forEach(select -> {
                expressions.addAll(getExpressions(select));
            });
        }
        return expressions;
    }


    /**
     * 判断是否是select *
     *
     * @param expression 列表达式
     * @return: boolean
     */
    public boolean isSelectAll(Expression expression) {
        return expression instanceof AllColumns || expression instanceof AllTableColumns;
    }

    /**
     * 通过执行sql的方式获取主键
     *
     * @return: java.util.List<java.util.List < java.lang.Object>>
     */
    public List<List<Object>> extractPkValues() throws SQLException {
        return extractColumnValues(this.getColumnWithAliasIndexes(this.getTableMeta().getPkNameListWithAlias(this.getAlias())));
    }


    /**
     * 从select的记录中获取指定列的值
     *
     * @param columnIndex
     * @return: java.util.List<java.util.List < java.lang.Object>>
     */
    public List<List<Object>> extractColumnValues(List<Integer> columnIndex) throws SQLException {
        List<List<Object>> insertValues = this.extractInsertValues();
        return this.getIndexValue4List(insertValues, columnIndex, false);
    }


    /**
     * 获取insert values或insert select后面的值集合
     *
     * @return: java.util.List<java.util.List < java.lang.Object>>
     */
    public List<List<Object>> extractInsertValues() throws SQLException {
        if (insertSelectColumnValues == null) {
            List<List<Object>> insertValues = new ArrayList<>();
            // value子句
            ItemsList itemsList = this.getStatement().getItemsList();
            // select子句
            Select select = this.getStatement().getSelect();
            // 获取values对应的值集合
            if (itemsList != null) {
                // 有value子句
                if (itemsList instanceof MultiExpressionList) {
                    // values后面跟多个()
                    MultiExpressionList multiExpressionList = (MultiExpressionList) itemsList;
                    String sql = expressions2Sql(multiExpressionList.getExpressionLists()); // sql中有全部的占位符
                    insertValues = this.executeParametersHolderSql(sql, this.multiExpressionListPlaceHolderIndex(multiExpressionList), this.getParameters());
                } else if (itemsList instanceof ExpressionList) {
                    // values后只有一个()
                    ExpressionList expressionList = (ExpressionList) itemsList;
                    String sql = expressions2Sql(expressionList);
                    insertValues = this.executeParametersHolderSql(sql, this.expressionListPlaceHolderIndex(expressionList.getExpressions()), this.getParameters());
                }
            } else {
                String subSql = select.toString();
                if (select.getWithItemsList() != null) {
                    subSql = "select  * from  ( " + subSql + " ) " + this.getTempTableName();
                }
                insertValues = this.executeParametersHolderSql(subSql, this.selectPlaceHolderIndex(select), this.getParameters());
            }
            this.insertSelectColumnValues = insertValues;
        }
        return insertSelectColumnValues;
    }

    public List<List<Object>> splitList(List<Object> list, int size) {
        List<List<Object>> result = new ArrayList<>();
        List objects = new ArrayList<>();
        for (int i = 0; i < list.size(); i++) {
            if (i % size == 0) {
                if (objects.size() > 0) {
                    result.add(objects);
                    objects = new ArrayList();
                }
            }
            objects.add(list.get(i));
        }
        if (objects.size() > 0) {
            result.add(objects);
        }
        return result;
    }

    /**
     * 获取expression集合常量值(占位符和常量值可能同时出现)
     *
     * @param expressions Expression集合
     * @return: java.util.List<java.util.List < java.lang.Object>>
     */
    public List<List<Object>> getConstantExpressionValue(List<Expression> expressions) throws SQLException {
        List<List<Object>> insertValues = new ArrayList<>();
        List<Integer> indexes = this.expressionListPlaceHolderIndex(expressions);
        Map<Integer, ArrayList<Object>> parameters = null;
        int paramsSize = 0;// addBatch了几次
        if (this.getStatementProxy() instanceof ParametersHolder) {
            ParametersHolder parametersHolder = (ParametersHolder) this.getStatementProxy();
            parameters = parametersHolder.getParameters();
            if (parameters.size() > 0) {
                for (Integer paraIndex : parameters.keySet()) {
                    paramsSize = parameters.get(paraIndex).size();
                    break;
                }
            }
        }

        if (indexes.size() == 0) { //没有带占位符的值
            Object[] objects = new Object[expressions.size()];
            for (int i = 0; i < expressions.size(); i++) {
                Expression expression = expressions.get(i);
                objects[i] = ConstantsExpressionUtil.getValue(expression);
            }
            insertValues.add(Arrays.asList(objects));
        } else {
            for (int i = 0; i < paramsSize; i++) {
                Object[] objects = new Object[expressions.size()];
                for (int j = 0; j < expressions.size(); j++) {
                    Expression expression = expressions.get(j);
                    if (expression instanceof JdbcParameter) {
                        List<Integer> index = new ArrayList<>();
                        ExpressionDeParser expressionDeParser = new ExpressionDeParser() {
                            @Override
                            public void visit(JdbcParameter jdbcParameter) {
                                index.add(jdbcParameter.getIndex());
                                super.visit(jdbcParameter);
                            }
                        };
                        expression.accept(expressionDeParser);
                        objects[j] = parameters.get(index.get(0)).get(i);
                    } else {
                        objects[j] = ConstantsExpressionUtil.getValue(expression);
                    }
                }
                insertValues.add(Arrays.asList(objects));
            }
        }
        return insertValues;
    }


    /**
     * 获取待插入的列名
     *
     * @return: java.util.List<java.lang.String>
     */
    public List<String> extractInsertColumns() {
        // 插入的列
        List<Column> columns = this.getStatement().getColumns();
        if (CollectionUtil.isNotEmpty(columns)) {
            return columnsName(columns);
        } else {
            return this.getTableMeta().getColumnNamesByOrder(this.getAlias());
        }
    }


    /**
     * 插入的列中是否包含主键
     *
     * @return: boolean
     */
    public boolean insertContainsPk() {
        List<Integer> pkIndex = this.getColumnWithAliasIndexes(this.getTableMeta().getPkNameListWithAlias(this.getAlias()));
        return CollectionUtil.isNotEmpty(pkIndex);
    }

    public boolean insertContainsAssignedColumns(List<String> columnNames) {
        List<Integer> columnIndexes = this.getColumnWithoutAliasIndexes(columnNames);
        return CollectionUtil.isNotEmpty(columnIndexes);
    }

    /**
     * 获取主键对应的下标
     *
     * @return: java.util.List<java.lang.Integer>
     */
    public List<Integer> getColumnWithAliasIndexes(List<String> columnNames) {
        if (CollectionUtil.isEmpty(columnNames)) {
            return new ArrayList<>();
        }
        List<String> insertColumns = this.extractInsertColumns();
        return IndexUtils.getIndexList(columnNames, insertColumns);
    }

    public List<Integer> getColumnWithoutAliasIndexes(List<String> columnNames) {
        List<String> columnsWithAlias = getColumnsWithAlias(columnNames);
        return getColumnWithAliasIndexes(columnsWithAlias);
    }

    public List<String> getColumnsWithAlias(List<String> columnNames) {
        if (CollectionUtil.isEmpty(columnNames)) {
            return new ArrayList<>();
        }
        List<String> columnsWithAlias = columnNames.stream().map(columnName -> this.getTableMeta().getColumnNameWithAlias(this.getAlias(), columnName, this.getDbType(), true)).collect(Collectors.toList());
        return columnsWithAlias;
    }


    public List<String> columnsName(List<Column> columns) {
        List<String> columnsName = new ArrayList<>();
        if (CollectionUtil.isNotEmpty(columns)) {
            for (Column column : columns) {
                Table table = column.getTable();
                String columnNameWithAlias = this.getTableMeta().getColumnNameWithAlias(table == null ? "" : table.getName(),
                        column.getColumnName(), this.getDbType(), true);
                columnsName.add(columnNameWithAlias);
            }
        }
        return columnsName;
    }
}
